import tensorflow as tf


class Model(tf.keras.Model):
    """Per-particle MLP regressor for fluid particles over their k nearest neighbours.

    For every fluid particle (input column 6 > 0) the model gathers its ``k``
    nearest particles (fluid or not). For each neighbour it builds a feature
    vector from three parts:

    * the neighbour's position relative to the fluid particle;
    * the neighbour's columns 3..6;
    * the fluid particle's columns 3..5.

    It runs that vector through a stack of tanh ``Dense`` layers and a final
    linear 3-unit ``Dense``. The result is max-pooled over the neighbours, and
    the constant ``GRAVITY`` offset is added.

    NOTE(review): columns 3..5 are presumably velocity and column 6 a
    fluid/boundary flag -- TODO confirm against the data pipeline.
    """

    # Constant offset added to every prediction. Presumably gravitational
    # acceleration with y as the vertical axis -- TODO confirm units/axis.
    GRAVITY = (0.0, -9.8, 0.0)

    def __init__(self, mlp_units, k=32, trainable=True, *args, **kwargs):
        """
        :param mlp_units: Iterable of ints. Width of each hidden tanh layer, in order.
        :param k: Int. Number of neighbours gathered per fluid particle.
        :param trainable: Bool. Forwarded to every Dense layer of this model.
        :param args: Forwarded to ``tf.keras.Model``.
        :param kwargs: Forwarded to ``tf.keras.Model``.
        """
        super().__init__(*args, **kwargs)
        self.k = k
        self.mlp_list = [
            tf.keras.layers.Dense(units, activation=tf.keras.activations.tanh, trainable=trainable)
            for units in mlp_units
        ]
        self.last_fcn = tf.keras.layers.Dense(3, activation=None, trainable=trainable)

    def _knn(self, k, select_mask, all_pos):
        """Find the k nearest neighbours of the selected particles among all particles.

        Brute force: all squared Euclidean distances between the F selected
        particles and the N particles are computed. Each selected particle's
        distance to itself is replaced by +inf, so a particle never appears in
        its own neighbour list.

        NOTE: this materialises an [F, N, 3] intermediate, so memory grows with F * N.

        :param k: Int. The K of kNN. Should be at most N - 1. ``tf.math.top_k``
            raises if k > N, and with k == N the particle itself is returned last.
        :param select_mask: Bool tensor, [N]. Selected particles are ``all_pos[select_mask]``, i.e. [F, 3].
        :param all_pos: Float tensor, [N, 3]. Coordinates of all particles.
        :return: Int32 tensor, [F, k]. Indices into ``all_pos`` of each selected
            particle's neighbours, nearest first.
        """
        select_pos = tf.boolean_mask(all_pos, select_mask)  # [F, 3]
        select_idx = tf.where(select_mask)[:, 0]  # [F], int64 row indices into all_pos
        diff = tf.expand_dims(select_pos, 1) - tf.expand_dims(all_pos, 0)  # [F, N, 3]
        dist2 = tf.reduce_sum(tf.square(diff), axis=-1)  # [F, N]
        all_idx = tf.range(tf.shape(all_pos, out_type=tf.int64)[0])  # [N], int64
        is_self = tf.equal(tf.expand_dims(select_idx, 1), tf.expand_dims(all_idx, 0))  # [F, N]
        # Push each particle's own entry to the very end of the ranking.
        dist2 = tf.where(is_self, tf.constant(float('inf'), dtype=dist2.dtype), dist2)
        # top_k picks the largest values, so rank by negated distance.
        _, neighbor_indices = tf.math.top_k(-dist2, k=k)  # [F, k]
        return neighbor_indices

    def pred(self, inputs):  # [N, 7]
        """Predict a 3-vector for every fluid particle.

        :param inputs: Float tensor, [N, 7]. Columns 0..2 are positions; a
            particle is fluid when column 6 > 0.
        :return: Tuple ``(pred, fluid_mask)``.
            ``pred`` is [F, 3], one row per fluid particle in input order.
            ``fluid_mask`` is the bool [N] mask that selected them.
        """
        all_particle_feature = inputs
        pos = all_particle_feature[:, :3]  # [N, 3]
        fluid_mask = all_particle_feature[:, 6] > 0  # [N]
        neighbor_indices = self._knn(self.k, fluid_mask, pos)  # [F, k]
        neighbor_feature = tf.gather(all_particle_feature, neighbor_indices)  # [F, k, 7]
        # tf.Tensor.__getitem__ does not accept boolean-tensor indices; use boolean_mask.
        fluid_feature = tf.boolean_mask(all_particle_feature, fluid_mask)[:, :6]  # [F, 6]
        _tile_fluid_feature = tf.tile(tf.expand_dims(fluid_feature, 1), (1, self.k, 1))  # [F, k, 6]
        pos_relative = neighbor_feature[:, :, :3] - _tile_fluid_feature[:, :, :3]  # [F, k, 3]
        # Concatenate along the feature axis (not the neighbour axis).
        feature_mlp = tf.concat(
            (pos_relative, neighbor_feature[:, :, 3:], _tile_fluid_feature[:, :, 3:]),
            axis=-1)  # [F, k, 3+4+3]
        for mlp in self.mlp_list:
            feature_mlp = mlp(feature_mlp)
        pred = self.last_fcn(feature_mlp)  # [F, k, 3]
        gravity = tf.constant([self.GRAVITY], dtype=pred.dtype)  # [1, 3]
        # Symmetric max-pooling over neighbours, then the constant offset.
        pred = tf.reduce_max(pred, axis=1) + gravity  # [F, 3]
        return pred, fluid_mask

    def loss(self, pred, truth, fluid_mask):
        """Mean squared error between ``pred`` [F, 3] and the fluid rows of ``truth`` [N, 3].

        NOTE(review): with no fluid particles (F == 0) the mean is over an empty
        tensor and evaluates to NaN.
        """
        return tf.reduce_mean(tf.square(pred - tf.boolean_mask(truth, fluid_mask)))

    def train(self, inputs, outputs, optimizer):
        """Run one optimisation step on a single frame.

        :param inputs: Float tensor, [N, 7]. Same layout as for ``pred``.
        :param outputs: Float tensor, [N, 3]. Targets; only fluid rows are used.
        :param optimizer: A ``tf.keras.optimizers`` instance.
        :return: Scalar loss, computed before the update is applied.
        """
        with tf.GradientTape() as tape:
            pred, fluid_mask = self.pred(inputs)
            current_loss = self.loss(pred, outputs, fluid_mask)
        # Compute and apply gradients outside the tape so these ops are not recorded.
        grads = tape.gradient(current_loss, self.trainable_variables)
        optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return current_loss
